import pandas as pd
from rdkit import Chem

from my_mol_tree import MolTree
import csv
import os
import sys

current_dir = os.path.dirname(__file__)
sys.path.append(current_dir)

import rdkit
import rdkit.Chem as Chem
import numpy as np
import copy
from chemutils import get_clique_mol, tree_decomp, brics_decomp, get_mol, get_smiles, set_atommap, \
    enum_assemble, decode_stereo
import os
import sys

current_dir = os.path.dirname(__file__)
sys.path.append(current_dir)


def get_slots(smiles):
    """Describe every atom of a SMILES string as a (symbol, charge, H-count) tuple.

    The SMILES is parsed with ``Chem.MolFromSmiles`` and one tuple is emitted
    per atom, in the order returned by ``GetAtoms()``:
    ``(GetSymbol(), GetFormalCharge(), GetTotalNumHs())``.

    NOTE: the parse result is used without any validity check.
    """
    parsed = Chem.MolFromSmiles(smiles)
    slots = []
    for atom in parsed.GetAtoms():
        descriptor = (atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())
        slots.append(descriptor)
    return slots


class Vocab(object):
    def __init__(self, smiles_list):
        self.vocab = smiles_list
        self.vmap = {x: i for i, x in enumerate(self.vocab)}
        self.slots = [get_slots(smiles) for smiles in self.vocab]

    def get_index(self, smiles):
        return self.vmap[smiles]

    def get_smiles(self, idx):
        return self.vocab[idx]

    def get_slots(self, idx):
        return copy.deepcopy(self.slots[idx])

    def size(self):
        return len(self.vocab)


class MolTreeNode(object):
    """One node of a molecule tree: a fragment SMILES plus its atom indices.

    ``recover`` additionally reads ``self.nid`` and ``self.is_leaf``; neither is
    set by ``__init__``, so whoever builds the tree must assign them first.
    """

    def __init__(self, smiles, clique=[]):
        """Store the fragment SMILES, its molecule and a copy of ``clique``.

        Args:
            smiles: SMILES string of the fragment.
            clique: atom indices (in the parent molecule) covered by this node.
                The default list is never mutated because it is copied below.
        """
        self.smiles = smiles
        self.mol = get_mol(self.smiles)
        # self.mol = cmol

        self.clique = [x for x in clique]  # copy, so the caller's list is not shared
        self.neighbors = []

    def add_neighbor(self, nei_node):
        """Append ``nei_node`` to this node's neighbor list (one direction only)."""
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        """Compute and store the label SMILES of this node with its neighbors.

        Temporarily writes atom-map numbers onto ``original_mol``, extracts the
        sub-molecule spanning this node's clique and all neighbor cliques, stores
        the result in ``self.label`` / ``self.label_mol``, then resets the map
        numbers of those atoms to 0.

        Args:
            original_mol: the full molecule whose atoms ``clique`` indices refer to;
                it is mutated during the call (atom-map numbers).

        Returns:
            The label SMILES string (also stored as ``self.label``).
        """
        clique = []
        clique.extend(self.clique)
        # Non-leaf nodes tag their own atoms with their node id.
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)
        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf:  # Leaf node, no need to mark
                continue
            for cidx in nei_node.clique:
                # Atoms shared with this node keep this node's tag, except that a
                # single-atom neighbor is allowed to override the atom mapping.
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node.nid)

        # Deduplicate atom indices shared between this node and its neighbors.
        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        # Round-trip through RDKit parsing/writing (presumably to normalise the
        # SMILES produced by get_smiles -- confirm against chemutils).
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
        self.label_mol = get_mol(self.label)

        # Undo the temporary tagging so original_mol can be reused by other nodes.
        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return self.label

    def assemble(self):
        """Enumerate attachment candidates and store them on the node.

        Neighbors are ordered single-atom nodes first, then multi-atom nodes by
        descending atom count, and passed to ``enum_assemble``. The first two
        components of each returned candidate are stored as ``self.cands`` and
        ``self.cand_mols``; both are empty lists when nothing is returned.
        """
        neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cands = enum_assemble(self, neighbors)
        if len(cands) > 0:
            # Each candidate is unpacked as a 3-tuple; the third field is discarded.
            self.cands, self.cand_mols, _ = zip(*cands)
            self.cands = list(self.cands)
            self.cand_mols = list(self.cand_mols)
        else:
            self.cands = []
            self.cand_mols = []


class MolTree(object):
    """Fragment decomposition of a molecule given as a SMILES string.

    Construction only computes ``cliques`` (lists of atom indices) and
    ``edges`` (pairs of clique indices): BRICS decomposition first, falling
    back to ``tree_decomp`` when BRICS yields at most one edge.

    The ``MolTreeNode`` objects are built lazily on first access to ``nodes``.
    Previously the attribute was never assigned, so ``size()``, ``recover()``
    and ``assemble()`` always raised ``AttributeError``. Building lazily keeps
    construction as cheap as before for callers that only need
    ``get_clique_dict()``.
    """

    def __init__(self, smiles):
        """Parse ``smiles`` and decompose the molecule into cliques and edges."""
        self.smiles = smiles
        self.mol = get_mol(smiles)
        # Stereo-isomer generation (smiles3D / smiles2D / decode_stereo) is
        # disabled in this version of the class.

        self.cliques, self.edges = brics_decomp(self.mol)
        if len(self.edges) <= 1:
            self.cliques, self.edges = tree_decomp(self.mol)
        self._nodes = None  # filled on first access to ``nodes``

    @property
    def nodes(self):
        """List of ``MolTreeNode`` objects, built once on first access."""
        # Look in __dict__ so instances created without __init__ still work.
        built = self.__dict__.get('_nodes')
        if built is None:
            built = self._build_nodes()
            self._nodes = built
        return built

    @nodes.setter
    def nodes(self, value):
        self._nodes = value

    def _build_nodes(self):
        """Create one node per clique, link neighbors, and assign ids/leaf flags."""
        nodes = []
        root = 0
        for i, c in enumerate(self.cliques):
            cmol = get_clique_mol(self.mol, c)
            nodes.append(MolTreeNode(get_smiles(cmol), c))
            # The clique containing atom 0 becomes the root.
            if min(c) == 0:
                root = i

        # Edges index into the clique order, so link before any reordering.
        for x, y in self.edges:
            nodes[x].add_neighbor(nodes[y])
            nodes[y].add_neighbor(nodes[x])

        if root > 0:
            nodes[0], nodes[root] = nodes[root], nodes[0]

        for i, node in enumerate(nodes):
            node.nid = i + 1
            if len(node.neighbors) > 1:  # Leaf node mol is not marked
                set_atommap(node.mol, node.nid)
            node.is_leaf = (len(node.neighbors) == 1)
        return nodes

    def get_clique_dict(self):
        """Return ``(cliques, edges)`` as computed at construction time."""
        return self.cliques, self.edges

    def size(self):
        """Return the number of tree nodes."""
        return len(self.nodes)

    def recover(self):
        """Call ``recover`` on every node, using this tree's molecule."""
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        """Call ``assemble`` on every node."""
        for node in self.nodes:
            node.assemble()


import csv


def getCliqueDict(input_path, output_path):
    """Write the distinct fragment SMILES found in a CSV of molecules.

    Each non-blank row's first column is parsed into a ``MolTree``; the
    ``smiles`` of every entry in its ``nodes`` is collected, and the distinct
    values are written to ``output_path`` one per line.

    Values are written in first-seen order so that repeated runs produce the
    same file (the previous set-based order varied between runs).

    Args:
        input_path: CSV file whose first column holds SMILES strings.
        output_path: text file to (over)write with one fragment SMILES per line.
    """
    # Silence RDKit warnings below CRITICAL while parsing many molecules.
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    seen = {}  # dict used as an insertion-ordered set

    print("start")
    with open(input_path, 'r') as file:
        for row in csv.reader(file):
            if not row:  # csv.reader yields [] for blank lines
                continue
            tree = MolTree(row[0])
            for node in tree.nodes:
                seen[node.smiles] = None

    print("Preprocessing Completed!")
    with open(output_path, 'w') as file:
        file.writelines(smiles + '\n' for smiles in seen)


def csv2txt(in_path, out_path):
    """Copy the first column of a CSV file into a text file, one value per line.

    Args:
        in_path: CSV file to read.
        out_path: text file to (over)write.
    """
    with open(in_path, 'r') as source:
        records = csv.reader(source)
        with open(out_path, 'w') as sink:
            sink.writelines(record[0] + '\n' for record in records)


def write_book(in_path, cli_path, edge_path,
               data_csv_path="../dataset/zinc_standard_agent/raw/newfile_data.csv"):
    """Decompose every molecule of a CSV and dump its cliques and edges.

    For each row, the first column is parsed into a ``MolTree``. Rows that are
    blank or fail to parse are skipped after printing their row index.

    * ``edge_path`` receives one line per processed molecule: ``str(edges)``.
    * ``cli_path`` receives one line per clique:
      ``"<row index> <clique index> /<clique list>"``.

    A molecule with neither cliques nor edges is stored as a single clique
    containing all of its atoms.

    Args:
        in_path: CSV file whose first column holds SMILES strings.
        cli_path: output text file for cliques.
        edge_path: output text file for edges.
        data_csv_path: side-output CSV that is (over)written at the end. No
            rows are collected at present, so the file ends up empty. Pass
            ``None`` to skip it. The default keeps the path that used to be
            hard-coded.
    """
    print("start")
    data_csv = []
    # Context managers close every handle, including the edge file, which was
    # previously left open.
    with open(in_path, 'r') as csv_file, \
            open(cli_path, 'w', encoding='utf-8') as cli_file, \
            open(edge_path, 'w', encoding='utf-8') as edges_file:
        for index, row in enumerate(csv.reader(csv_file)):
            if not row:  # blank line: treat like an unparsable molecule
                print(index)
                continue
            line = row[0]
            try:
                mol = MolTree(line)
            except Exception:
                # Report the offending row and move on; unlike a bare
                # ``except`` this still lets Ctrl-C interrupt the run.
                print(index)
                continue
            clique, edges = mol.get_clique_dict()
            # ``len(a) == 0 & len(b) == 0`` parsed as a chained comparison
            # that never looked at the edges; test both explicitly.
            if len(clique) == 0 and len(edges) == 0:
                n_atoms = Chem.MolFromSmiles(line).GetNumAtoms()
                clique = [list(range(n_atoms))]
                print(index)
                print(clique)
            edges_file.write(str(edges) + '\n')
            for index1, cli in enumerate(clique):
                cli_file.write('{} {} /{}\n'.format(index, index1, cli))
    if data_csv_path is not None:
        with open(data_csv_path, 'w', newline='') as file:
            csv.writer(file).writerows(data_csv)


def read_book(cli_path, edge_path):
    """Load the clique and edge files produced by ``write_book``.

    Args:
        cli_path: text file with lines ``"<mol index> <clique index> /<list>"``.
        edge_path: text file with one ``str(list of pairs)`` per line.

    Returns:
        ``(cliques, all_edge)`` where ``cliques[k]`` is the list of cliques of
        the k-th molecule present in the clique file (consecutive lines with
        the same molecule index are grouped), and ``all_edge[k]`` is the k-th
        line of the edge file with every pair converted to a two-element list.
        An empty clique file yields an empty ``cliques`` list.
    """
    # NOTE(security): eval() executes whatever the files contain. It is kept
    # because the stored reprs may not be plain literals (e.g. numpy scalar
    # reprs); only feed this function files written by write_book.
    with open(cli_path, 'r', encoding='utf-8') as cli_file:
        data = cli_file.readlines()

    all_edge = []
    with open(edge_path, 'r', encoding='utf-8') as edge_file:
        for e in edge_file:
            pairs = eval(e.strip('\n'))
            # Convert the tuples into a two-dimensional list.
            all_edge.append([[x, y] for x, y in pairs])

    cliques = []
    current = []
    # None means "no line seen yet". Starting from a fixed 0 used to emit a
    # spurious empty group whenever the first stored molecule index was not 0
    # (e.g. row 0 failed to parse), shifting cliques against edges.
    prev_mol = None
    for raw in data:
        parts = raw.strip('\n').split(' /')
        header, payload = parts[0], parts[1]
        mol_text, _clique_text = header.split(' ')
        mol = int(mol_text)
        if prev_mol is not None and mol != prev_mol:
            cliques.append(current)
            current = []
        current.append(eval(payload))
        prev_mol = mol
    if prev_mol is not None:
        cliques.append(current)
    return cliques, all_edge


if __name__ == '__main__':
    # Script entry point: decompose every molecule listed in the dataset's
    # smiles.csv and write the clique / edge dictionaries next to it.
    # (read_book loads the two files back; csv2txt can derive a smiles.txt
    # from the same CSV when a plain-text list is needed.)
    class_name = "mutag"
    root_path = "../dataset/" + class_name + "/raw/"
    write_book(
        root_path + "smiles.csv",
        root_path + "clique_dict.txt",
        root_path + "edge_dict.txt",
    )
